{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PASCAL VOC - Light-Weight RefineNet\n",
    "\n",
    "## 20 semantic classes + background\n",
    "\n",
    "### Light-Weight RefineNet based on ResNet-50/101/152 and MobileNet-v2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import six\n",
    "import sys\n",
    "sys.path.append('../../')\n",
    "\n",
    "from models.mobilenet import mbv2\n",
    "from models.resnet import rf_lw50, rf_lw101, rf_lw152"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.helpers import prepare_img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import glob\n",
    "\n",
    "import cv2\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paths and constants shared by the cells below.\n",
    "\n",
    "# Lookup array: the cells below index it with integer label maps to get\n",
    "# the colour images that are displayed.\n",
    "cmap = np.load('../../utils/cmap.npy')\n",
    "has_cuda = torch.cuda.is_available()\n",
    "img_dir = '../imgs/VOC/'\n",
    "# glob returns matches in arbitrary, filesystem-dependent order; sort them\n",
    "# so the rows of the figure come out in the same order on every run.\n",
    "imgs = sorted(glob.glob('{}*.jpg'.format(img_dir)))\n",
    "n_classes = 21  # 20 semantic classes + background"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialise models\n",
    "# Maps the name later used as the plot title to the constructor of that network.\n",
    "model_inits = {\n",
    "    'rf_lw50_voc': rf_lw50,\n",
    "    'rf_lw101_voc': rf_lw101,\n",
    "    'rf_lw152_voc': rf_lw152,\n",
    "    'rf_lwmbv2_voc': mbv2,\n",
    "}\n",
    "\n",
    "\n",
    "def _build(constructor):\n",
    "    \"\"\"Create one network in eval mode and move it to the GPU when CUDA is available.\n",
    "\n",
    "    `pretrained=True` is forwarded to the constructor (presumably this loads\n",
    "    the released VOC weights - confirm against the `models` package).\n",
    "    \"\"\"\n",
    "    network = constructor(n_classes, pretrained=True).eval()\n",
    "    return network.cuda() if has_cuda else network\n",
    "\n",
    "\n",
    "models = {name: _build(ctor) for name, ctor in six.iteritems(model_inits)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/vladimir/Documents/venvs/darts/lib/python3.6/site-packages/torch/nn/modules/upsampling.py:122: UserWarning: nn.Upsampling is deprecated. Use nn.functional.interpolate instead.\n",
      "  warnings.warn(\"nn.Upsampling is deprecated. Use nn.functional.interpolate instead.\")\n"
     ]
    }
   ],
   "source": [
    "# Figure 2 from the paper: one row per image showing the input, the\n",
    "# ground-truth mask and the prediction of every model.\n",
    "n_cols = len(models) + 2 # 1 - for image, 1 - for GT\n",
    "n_rows = len(imgs)\n",
    "\n",
    "\n",
    "def show_panel(position, picture, title):\n",
    "    \"\"\"Draw `picture` with `title` in cell `position` (1-based) of the n_rows x n_cols grid.\"\"\"\n",
    "    plt.subplot(n_rows, n_cols, position)\n",
    "    plt.imshow(picture)\n",
    "    plt.title(title)\n",
    "    plt.axis('off')\n",
    "\n",
    "\n",
    "plt.figure(figsize=(16, 12))\n",
    "idx = 1  # running subplot position, filled row by row\n",
    "\n",
    "with torch.no_grad():\n",
    "    for img_path in imgs:\n",
    "        img = np.array(Image.open(img_path))\n",
    "        # The mask is stored next to the image with a .png extension. Swap only\n",
    "        # the trailing extension: str.replace('jpg', 'png') would also rewrite\n",
    "        # any 'jpg' that occurs in a directory or file name.\n",
    "        msk_path = img_path.rsplit('.', 1)[0] + '.png'\n",
    "        msk = cmap[np.array(Image.open(msk_path))]\n",
    "        # shape[:2] is (height, width); cv2.resize takes (width, height).\n",
    "        orig_size = img.shape[:2][::-1]\n",
    "\n",
    "        # HWC -> CHW plus a leading batch axis of size 1.\n",
    "        img_inp = torch.tensor(prepare_img(img).transpose(2, 0, 1)[None]).float()\n",
    "        if has_cuda:\n",
    "            img_inp = img_inp.cuda()\n",
    "\n",
    "        show_panel(idx, img, 'img')\n",
    "        idx += 1\n",
    "        show_panel(idx, msk, 'gt')\n",
    "        idx += 1\n",
    "\n",
    "        for mname, mnet in six.iteritems(models):\n",
    "            # Autograd is disabled by torch.no_grad(), so the output can be\n",
    "            # converted to numpy directly without the legacy `.data` access.\n",
    "            segm = mnet(img_inp)[0].cpu().numpy().transpose(1, 2, 0)\n",
    "            # Upsample the per-class scores to the input resolution first,\n",
    "            # then take the most likely class for every pixel.\n",
    "            segm = cv2.resize(segm, orig_size, interpolation=cv2.INTER_CUBIC)\n",
    "            segm = cmap[segm.argmax(axis=2).astype(np.uint8)]\n",
    "\n",
    "            show_panel(idx, segm, mname)\n",
    "            idx += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
